import torch


def to_device(obj, device, *, non_blocking=False):
    """Recursively move every tensor inside *obj* to *device*.

    Containers are traversed and rebuilt so the overall structure of *obj*
    is preserved:

    * ``list``  -> new ``list``
    * ``tuple`` -> new ``tuple``; namedtuples are rebuilt as the same
      namedtuple type so field access keeps working
    * ``dict``  -> new plain ``dict`` with the same keys, in the same order

    Tensors are moved with ``Tensor.to``. Any other value is returned
    unchanged and is not copied.

    Args:
        obj: A tensor, an arbitrarily nested list/tuple/dict that may contain
            tensors, or any other object.
        device: Target forwarded to ``Tensor.to`` (e.g. ``"cpu"`` or a
            ``torch.device``).
        non_blocking: Forwarded unchanged to ``Tensor.to``. Keyword-only, and
            defaults to ``False``, which matches the previous behaviour.

    Returns:
        A new structure mirroring *obj* with all tensors moved to *device*.
    """
    if torch.is_tensor(obj):
        return obj.to(device, non_blocking=non_blocking)

    if isinstance(obj, dict):
        return {
            key: to_device(value, device, non_blocking=non_blocking)
            for key, value in obj.items()
        }

    if isinstance(obj, tuple):
        moved = [to_device(x, device, non_blocking=non_blocking) for x in obj]
        # Namedtuples take their fields positionally, so they cannot be rebuilt
        # from one iterable; ``_make`` handles that. Other tuple subclasses
        # (e.g. structseq types without ``_make``) fall back to a plain tuple.
        if hasattr(obj, "_fields") and hasattr(obj, "_make"):
            return type(obj)._make(moved)
        return tuple(moved)

    if isinstance(obj, list):
        return [to_device(x, device, non_blocking=non_blocking) for x in obj]

    # Non-tensor, non-container leaves (ints, strings, None, ...) pass through.
    return obj
